{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install torch==2.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from string import punctuation\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "torch.use_deterministic_algorithms(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Instructions:\n",
    "\n",
    "- You need to download the dataset from here: https://ai.stanford.edu/~amaas/data/sentiment/\n",
    "- Then, you need to unzip the downloaded zipped file in the present working directory\n",
    "- That should result in a new folder named aclImdb under the present working directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [00:00<00:00, 46053.84it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12500/12500 [00:00<00:00, 46676.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of reviews : 25000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# read sentiments and reviews data from the text files\n",
    "review_list = []\n",
    "label_list = []\n",
    "for label in ['pos', 'neg']:\n",
    "    for fname in tqdm(os.listdir(f'./aclImdb/train/{label}/')):\n",
    "        if 'txt' not in fname:\n",
    "            continue\n",
    "        with open(os.path.join(f'./aclImdb/train/{label}/', fname), encoding=\"utf8\") as f:\n",
    "            review_list += [f.read()]\n",
    "            label_list += [label]\n",
    "print ('Number of reviews :', len(review_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25000/25000 [00:01<00:00, 13191.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('the', 334691), ('and', 162228), ('a', 161940), ('of', 145326), ('to', 135042), ('is', 106855), ('in', 93028), ('it', 77099), ('i', 75719), ('this', 75190)]\n"
     ]
    }
   ],
   "source": [
    "# pre-processing review text\n",
    "review_list = [review.lower() for review in review_list]\n",
    "review_list = [''.join([letter for letter in review if letter not in punctuation]) for review in tqdm(review_list)]\n",
    "\n",
    "# accumulate all review texts together\n",
    "reviews_blob = ' '.join(review_list)\n",
    "\n",
    "# generate list of all words of all reviews\n",
    "review_words = reviews_blob.split()\n",
    "\n",
    "# get the word counts\n",
    "count_words = Counter(review_words)\n",
    "\n",
    "# sort words as per counts (decreasing order)\n",
    "total_review_words = len(review_words)\n",
    "sorted_review_words = count_words.most_common(total_review_words)\n",
    "\n",
    "print(sorted_review_words[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('the', 1), ('and', 2), ('a', 3), ('of', 4), ('to', 5), ('is', 6), ('in', 7), ('it', 8), ('i', 9), ('this', 10)]\n"
     ]
    }
   ],
   "source": [
    "# create word to integer (token) dictionary in order to encode text as numbers\n",
    "vocab_to_token = {word:idx+1 for idx, (word, count) in enumerate(sorted_review_words)}\n",
    "print(list(vocab_to_token.items())[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "one of disneys best films that i can enjoy watching often you may easily guess the outcome but who cares its just plain fun escape for 1 hour fortytwo minutes and after all wasnt movies meant to get away from reality for just a short time anyway the cast sparkles with delight magictrain\n",
      "\n",
      "[28, 4, 4534, 114, 94, 11, 9, 68, 343, 146, 388, 22, 193, 683, 477, 1, 3709, 18, 36, 2264, 29, 40, 1017, 249, 1072, 15, 467, 558, 36644, 232, 2, 100, 31, 267, 92, 945, 5, 75, 241, 35, 638, 15, 40, 3, 347, 61, 577, 1, 176, 17217, 16, 2991, 43770]\n"
     ]
    }
   ],
   "source": [
    "reviews_tokenized = []\n",
    "for review in review_list:\n",
    "    word_to_token = [vocab_to_token[word] for word in review.split()]\n",
    "    reviews_tokenized.append(word_to_token)\n",
    "print(review_list[0])\n",
    "print()\n",
    "print (reviews_tokenized[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode sentiments as 0 or 1\n",
    "encoded_label_list = [1 if label =='pos' else 0 for label in label_list]\n",
    "\n",
    "reviews_len = [len(review) for review in reviews_tokenized]\n",
    "\n",
    "reviews_tokenized = [reviews_tokenized[i] for i, l in enumerate(reviews_len) if l>0 ]\n",
    "encoded_label_list = np.array([encoded_label_list[i] for i, l in enumerate(reviews_len) if l> 0 ], dtype='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjoAAAGdCAYAAAAbudkLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuXklEQVR4nO3df3BUVZ7//1cAuwGlEyAknYwBwg/5IeGnY2xHUJZsGkipGdldBBR0IgxOcJQgP6KIAbY2DBQ4zIiwlmLcGhRkS1DBRUIAI9IgRAIGJCWYGF3TYQYkzS8DIff7h5/cr72AGO0k5Ph8VN2q3Hve9/Q5h9j98vbtTphlWZYAAAAM1KyxBwAAAFBfCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGO1aOwBNKaamhp9/fXXatOmjcLCwhp7OAAA4EewLEunTp1SbGysmjX74Ws2v+ig8/XXXysuLq6xhwEAAH6CL7/8UjfeeOMP1vyig06bNm0kfbdQLperkUcDAAB+jEAgoLi4OPt1/If8ooNO7dtVLpeLoAMAQBPzY2474WZkAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGO1qOsJ+fn5WrRokQoKClReXq5169YpNTXVbr/Sn0xfuHChpk+fLknq3Lmzvvjii6D27OxszZo1y94/cOCA0tPTtWfPHnXo0EGPPfaYZsyYEXTO2rVr9cwzz6i0tFTdu3fXn/70J40cObKuU6oXnWdtbOwh1FnpgpTGHgIAACFV5ys6Z86cUb9+/bRs2bLLtpeXlwdtK1euVFhYmEaNGhVUN2/evKC6xx57zG4LBAJKTk5Wp06dVFBQoEWLFikrK0svvviiXbNz506NGTNGaWlp2rdvn1JTU5WamqqioqK6TgkAABiqzld0RowYoREjRlyx3e12B+2/9dZbGjp0qLp06RJ0vE2bNpfU1lq1apXOnz+vlStXyuFw6Oabb1ZhYaGWLFmiSZMmSZKWLl2q4cOH21eJ5s+fr9zcXD3//PNasWJFXacFAAAMVK/36FRUVGjjxo1KS0u7pG3BggVq3769BgwYoEWLFqm6utpu8/l8GjJkiBwOh33M6/WquLhY33zzjV2TlJQU1KfX65XP57vieKqqqhQIBII2AABgrjpf0amLV199VW3atNF9990XdPyPf/yjBg4cqHbt2mnnzp3KzMxUeXm5lixZIkny+/2Kj48POic6Otpua9u2rfx+v33s+zV+v/+K48nOztbcuXNDMTUAANAE1GvQWblypcaNG6eWLVsGHc/IyLB/7tu3rxwOh37/+98rOztbTqez3saTmZkZ9NiBQEBxcXH19ngAAKBx1VvQ+eCDD1RcXKw1a9ZctTYxMVHV1dUqLS1Vjx495Ha7VVFREVRTu197X8+Vaq50348kOZ3Oeg1SAADg2lJv9+i8/PLLGjRokPr163fV2sLCQjVr1kxRUVGSJI/Ho/z8fF24cMGuyc3NVY8ePdS2bVu7Ji8vL6if3NxceTyeEM4CAAA0ZXUOOqdPn1ZhYaEKCwslSSUlJSosLFRZWZldEwgEtHbtWj3yyCOXnO/z+fTnP/9Z+/fv1+eff65Vq1Zp6tSpeuCBB+wQM3bsWDkcDqWlpengwYNas2aNli5dGvS20+OPP65NmzZp8eLFOnz4sLKysrR3715NmTKlrlMCAACGqvNbV3v37tXQoUPt/drwMWHCBOXk5EiSVq9eLcuyNGbMmEvOdzqdWr16tbKyslRVVaX4+HhNnTo1KMSEh4dr8+bNSk9P16BBgxQZGak5c+bYHy2XpNtvv12vvfaaZs+eraeeekrdu3fX+vXr1adPn7pOCQAAGCrMsiyrsQfRWAKBgMLDw1VZWSmXyxXSvvlmZAAA6kdd
Xr/5W1cAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAY9U56OTn5+vuu+9WbGyswsLCtH79+qD2hx56SGFhYUHb8OHDg2pOnDihcePGyeVyKSIiQmlpaTp9+nRQzYEDBzR48GC1bNlScXFxWrhw4SVjWbt2rXr27KmWLVsqISFB7777bl2nAwAADFbnoHPmzBn169dPy5Ytu2LN8OHDVV5ebm+vv/56UPu4ceN08OBB5ebmasOGDcrPz9ekSZPs9kAgoOTkZHXq1EkFBQVatGiRsrKy9OKLL9o1O3fu1JgxY5SWlqZ9+/YpNTVVqampKioqquuUAACAocIsy7J+8slhYVq3bp1SU1PtYw899JBOnjx5yZWeWp9++ql69+6tPXv26JZbbpEkbdq0SSNHjtRXX32l2NhYLV++XE8//bT8fr8cDockadasWVq/fr0OHz4sSRo9erTOnDmjDRs22H3fdttt6t+/v1asWPGjxh8IBBQeHq7Kykq5XK6fsAJX1nnWxpD21xBKF6Q09hAAALiqurx+18s9Otu3b1dUVJR69OihRx99VMePH7fbfD6fIiIi7JAjSUlJSWrWrJl2795t1wwZMsQOOZLk9XpVXFysb775xq5JSkoKelyv1yufz3fFcVVVVSkQCARtAADAXCEPOsOHD9d//dd/KS8vT3/605/0/vvva8SIEbp48aIkye/3KyoqKuicFi1aqF27dvL7/XZNdHR0UE3t/tVqatsvJzs7W+Hh4fYWFxf38yYLAACuaS1C3eH9999v/5yQkKC+ffuqa9eu2r59u4YNGxbqh6uTzMxMZWRk2PuBQICwAwCAwer94+VdunRRZGSkjhw5Iklyu906duxYUE11dbVOnDght9tt11RUVATV1O5fraa2/XKcTqdcLlfQBgAAzFXvQeerr77S8ePHFRMTI0nyeDw6efKkCgoK7JqtW7eqpqZGiYmJdk1+fr4uXLhg1+Tm5qpHjx5q27atXZOXlxf0WLm5ufJ4PPU9JQAA0ETUOeicPn1ahYWFKiwslCSVlJSosLBQZWVlOn36tKZPn65du3aptLRUeXl5uvfee9WtWzd5vV5JUq9evTR8+HBNnDhRH330kT788ENNmTJF999/v2JjYyVJY8eOlcPhUFpamg4ePKg1a9Zo6dKlQW87Pf7449q0aZMWL16sw4cPKysrS3v37tWUKVNCsCwAAMAEdQ46e/fu1YABAzRgwABJUkZGhgYMGKA5c+aoefPmOnDggO655x7ddNNNSktL06BBg/TBBx/I6XTafaxatUo9e/bUsGHDNHLkSN1xxx1B35ETHh6uzZs3q6SkRIMGDdK0adM0Z86coO/auf322/Xaa6/pxRdfVL9+/fTf//3fWr9+vfr06fNz1gMAABjkZ32PTlPH9+gE43t0AABNQaN/jw4AAMC1gKADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYKw6
B538/Hzdfffdio2NVVhYmNavX2+3XbhwQTNnzlRCQoKuv/56xcbGavz48fr666+D+ujcubPCwsKCtgULFgTVHDhwQIMHD1bLli0VFxenhQsXXjKWtWvXqmfPnmrZsqUSEhL07rvv1nU6AADAYHUOOmfOnFG/fv20bNmyS9rOnj2rjz/+WM8884w+/vhjvfnmmyouLtY999xzSe28efNUXl5ub4899pjdFggElJycrE6dOqmgoECLFi1SVlaWXnzxRbtm586dGjNmjNLS0rRv3z6lpqYqNTVVRUVFdZ0SAAAwVIu6njBixAiNGDHism3h4eHKzc0NOvb888/r1ltvVVlZmTp27Ggfb9Omjdxu92X7WbVqlc6fP6+VK1fK4XDo5ptvVmFhoZYsWaJJkyZJkpYuXarhw4dr+vTpkqT58+crNzdXzz//vFasWFHXaQEAAAPV+z06lZWVCgsLU0RERNDxBQsWqH379howYIAWLVqk6upqu83n82nIkCFyOBz2Ma/Xq+LiYn3zzTd2TVJSUlCfXq9XPp/vimOpqqpSIBAI2gAAgLnqfEWnLr799lvNnDlTY8aMkcvlso//8Y9/1MCBA9WuXTvt3LlTmZmZKi8v15IlSyRJfr9f8fHxQX1FR0fbbW3btpXf77ePfb/G7/dfcTzZ2dmaO3duqKYHAACucfUWdC5cuKB/+7d/k2VZWr58eVBbRkaG/XPfvn3lcDj0+9//XtnZ2XI6nfU1JGVmZgY9diAQUFxcXL09HgAAaFz1EnRqQ84XX3yhrVu3Bl3NuZzExERVV1ertLRUPXr0kNvtVkVFRVBN7X7tfT1XqrnSfT+S5HQ66zVIAQCAa0vI79GpDTmfffaZtmzZovbt21/1nMLCQjVr1kxRUVGSJI/Ho/z8fF24cMGuyc3NVY8ePdS2bVu7Ji8vL6if3NxceTyeEM4GAAA0ZXW+onP69GkdOXLE3i8pKVFhYaHatWunmJgY/cu//Is+/vhjbdiwQRcvXrTvmWnXrp0cDod8Pp92796toUOHqk2bNvL5fJo6daoeeOABO8SMHTtWc+fOVVpammbOnKmioiItXbpUzz33nP24jz/+uO68804tXrxYKSkpWr16tfbu3Rv0EXQAAPDLFmZZllWXE7Zv366hQ4decnzChAnKysq65CbiWtu2bdNdd92ljz/+WH/4wx90+PBhVVVVKT4+Xg8++KAyMjKC3lY6cOCA0tPTtWfPHkVGRuqxxx7TzJkzg/pcu3atZs+erdLSUnXv3l0LFy7UyJEjf/RcAoGAwsPDVVlZedW31+qq86yNIe2vIZQuSGnsIQAAcFV1ef2uc9AxCUEnGEEHANAU1OX1m791BQAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGCsOged/Px83X333YqNjVVYWJjWr18f1G5ZlubMmaOYmBi1atVKSUlJ+uyzz4JqTpw4oXHjxsnlcikiIkJpaWk6ffp0UM2BAwc0ePBgtWzZUnFxcVq4cOElY1m7dq169uypli1bKiEhQe+++25dpwMAAAxW56Bz5swZ9evXT8uWLbts+8KFC/WXv/xFK1as0O7du3X99dfL6/Xq22+/tWvGjRungwcPKjc3Vxs2bFB+fr4mTZpktwcCASUnJ6tTp04qKCjQokWLlJWVpRdffNGu2blzp8aMGaO0tDTt27dPqampSk1NVVFRUV2nBAAADBVmWZb1k08OC9O6deuUmpoq6bur
ObGxsZo2bZqefPJJSVJlZaWio6OVk5Oj+++/X59++ql69+6tPXv26JZbbpEkbdq0SSNHjtRXX32l2NhYLV++XE8//bT8fr8cDockadasWVq/fr0OHz4sSRo9erTOnDmjDRs22OO57bbb1L9/f61YseJHjT8QCCg8PFyVlZVyuVw/dRkuq/OsjSHtryGULkhp7CEAAHBVdXn9Duk9OiUlJfL7/UpKSrKPhYeHKzExUT6fT5Lk8/kUERFhhxxJSkpKUrNmzbR79267ZsiQIXbIkSSv16vi4mJ98803ds33H6e2pvZxLqeqqkqBQCBoAwAA5gpp0PH7/ZKk6OjooOPR0dF2m9/vV1RUVFB7ixYt1K5du6Cay/Xx/ce4Uk1t++VkZ2crPDzc3uLi4uo6RQAA0IT8oj51lZmZqcrKSnv78ssvG3tIAACgHoU06LjdbklSRUVF0PGKigq7ze1269ixY0Ht1dXVOnHiRFDN5fr4/mNcqaa2/XKcTqdcLlfQBgAAzBXSoBMfHy+32628vDz7WCAQ0O7du+XxeCRJHo9HJ0+eVEFBgV2zdetW1dTUKDEx0a7Jz8/XhQsX7Jrc3Fz16NFDbdu2tWu+/zi1NbWPAwAAUOegc/r0aRUWFqqwsFDSdzcgFxYWqqysTGFhYXriiSf07//+73r77bf1ySefaPz48YqNjbU/mdWrVy8NHz5cEydO1EcffaQPP/xQU6ZM0f3336/Y2FhJ0tixY+VwOJSWlqaDBw9qzZo1Wrp0qTIyMuxxPP7449q0aZMWL16sw4cPKysrS3v37tWUKVN+/qoAAAAjtKjrCXv37tXQoUPt/drwMWHCBOXk5GjGjBk6c+aMJk2apJMnT+qOO+7Qpk2b1LJlS/ucVatWacqUKRo2bJiaNWumUaNG6S9/+YvdHh4ers2bNys9PV2DBg1SZGSk5syZE/RdO7fffrtee+01zZ49W0899ZS6d++u9evXq0+fPj9pIQAAgHl+1vfoNHV8j04wvkcHANAUNNr36AAAAFxLCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxgp50OncubPCwsIu2dLT0yVJd9111yVtkydPDuqjrKxMKSkpat26taKiojR9+nRVV1cH1Wzfvl0DBw6U0+lUt27dlJOTE+qpAACAJq5FqDvcs2ePLl68aO8XFRXpn//5n/Wv//qv9rGJEydq3rx59n7r1q3tny9evKiUlBS53W7t3LlT5eXlGj9+vK677jr9x3/8hySppKREKSkpmjx5slatWqW8vDw98sgjiomJkdfrDfWUAABAExXyoNOhQ4eg/QULFqhr166688477WOtW7eW2+2+7PmbN2/WoUOHtGXLFkVHR6t///6aP3++Zs6cqaysLDkcDq1YsULx8fFavHixJKlXr17asWOHnnvuOYIOAACw1es9OufPn9ff/vY3/e53v1NYWJh9fNWqVYqMjFSfPn2UmZmps2fP2m0+n08JCQmKjo62j3m9XgUCAR08eNCuSUpKCnosr9crn8/3g+OpqqpSIBAI2gAAgLlCfkXn+9avX6+TJ0/qoYceso+NHTtWnTp1UmxsrA4cOKCZM2equLhYb775piTJ7/cHhRxJ9r7f7//BmkAgoHPnzqlVq1aXHU92drbmzp0bqukBAIBrXL0GnZdfflkjRoxQbGysfWzSpEn2zwkJCYqJidGwYcN09OhRde3atT6Ho8zMTGVkZNj7gUBAcXFx
9fqYAACg8dRb0Pniiy+0ZcsW+0rNlSQmJkqSjhw5oq5du8rtduujjz4KqqmoqJAk+74et9ttH/t+jcvluuLVHElyOp1yOp11ngsAAGia6u0enVdeeUVRUVFKSUn5wbrCwkJJUkxMjCTJ4/Hok08+0bFjx+ya3NxcuVwu9e7d267Jy8sL6ic3N1cejyeEMwAAAE1dvQSdmpoavfLKK5owYYJatPj/LxodPXpU8+fPV0FBgUpLS/X2229r/PjxGjJkiPr27StJSk5OVu/evfXggw9q//79eu+99zR79mylp6fbV2MmT56szz//XDNmzNDhw4f1wgsv6I033tDUqVPrYzoAAKCJqpegs2XLFpWVlel3v/td0HGHw6EtW7YoOTlZPXv21LRp0zRq1Ci98847dk3z5s21YcMGNW/eXB6PRw888IDGjx8f9L078fHx2rhxo3Jzc9WvXz8tXrxYL730Eh8tBwAAQcIsy7IaexCNJRAIKDw8XJWVlXK5XCHtu/OsjSHtryGULvjhtxkBALgW1OX1m791BQAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYLRp7ALh2dJ61sbGHUGelC1IaewgAgGsYV3QAAICxQh50srKyFBYWFrT17NnTbv/222+Vnp6u9u3b64YbbtCoUaNUUVER1EdZWZlSUlLUunVrRUVFafr06aqurg6q2b59uwYOHCin06lu3bopJycn1FMBAABNXL1c0bn55ptVXl5ubzt27LDbpk6dqnfeeUdr167V+++/r6+//lr33Xef3X7x4kWlpKTo/Pnz2rlzp1599VXl5ORozpw5dk1JSYlSUlI0dOhQFRYW6oknntAjjzyi9957rz6mAwAAmqh6uUenRYsWcrvdlxyvrKzUyy+/rNdee03/9E//JEl65ZVX1KtXL+3atUu33XabNm/erEOHDmnLli2Kjo5W//79NX/+fM2cOVNZWVlyOBxasWKF4uPjtXjxYklSr169tGPHDj333HPyer31MSUAANAE1csVnc8++0yxsbHq0qWLxo0bp7KyMklSQUGBLly4oKSkJLu2Z8+e6tixo3w+nyTJ5/MpISFB0dHRdo3X61UgENDBgwftmu/3UVtT28eVVFVVKRAIBG0AAMBcIQ86iYmJysnJ0aZNm7R8+XKVlJRo8ODBOnXqlPx+vxwOhyIiIoLOiY6Olt/vlyT5/f6gkFPbXtv2QzWBQEDnzp274tiys7MVHh5ub3FxcT93ugAA4BoW8reuRowYYf/ct29fJSYmqlOnTnrjjTfUqlWrUD9cnWRmZiojI8PeDwQChB0AAAxW7x8vj4iI0E033aQjR47I7Xbr/PnzOnnyZFBNRUWFfU+P2+2+5FNYtftXq3G5XD8YppxOp1wuV9AGAADMVe9B5/Tp0zp69KhiYmI0aNAgXXfddcrLy7Pbi4uLVVZWJo/HI0nyeDz65JNPdOzYMbsmNzdXLpdLvXv3tmu+30dtTW0fAAAAUj0EnSeffFLvv/++SktLtXPnTv32t79V8+bNNWbMGIWHhystLU0ZGRnatm2bCgoK9PDDD8vj8ei2226TJCUnJ6t379568MEHtX//fr333nuaPXu20tPT5XQ6JUmTJ0/W559/rhkzZujw4cN64YUX9MYbb2jq1Kmhng4AAGjCQn6PzldffaUxY8bo+PHj6tChg+644w7t2rVLHTp0kCQ999xzatasmUaNGqWqqip5vV698MIL9vnNmzfXhg0b
9Oijj8rj8ej666/XhAkTNG/ePLsmPj5eGzdu1NSpU7V06VLdeOONeumll/hoOQAACBJmWZbV2INoLIFAQOHh4aqsrAz5/TpN8e9GNUX8rSsA+OWpy+s3f+sKAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLEIOgAAwFghDzrZ2dn69a9/rTZt2igqKkqpqakqLi4OqrnrrrsUFhYWtE2ePDmopqysTCkpKWrdurWioqI0ffp0VVdXB9Vs375dAwcOlNPpVLdu3ZSTkxPq6QAAgCYs5EHn/fffV3p6unbt2qXc3FxduHBBycnJOnPmTFDdxIkTVV5ebm8LFy602y5evKiUlBSdP39eO3fu1KuvvqqcnBzNmTPHrikpKVFKSoqGDh2qwsJCPfHEE3rkkUf03nvvhXpKAACgiWoR6g43bdoUtJ+Tk6OoqCgVFBRoyJAh9vHWrVvL7XZfto/Nmzfr0KFD2rJli6Kjo9W/f3/Nnz9fM2fOVFZWlhwOh1asWKH4+HgtXrxYktSrVy/t2LFDzz33nLxeb6inBQAAmqB6v0ensrJSktSuXbug46tWrVJkZKT69OmjzMxMnT171m7z+XxKSEhQdHS0fczr9SoQCOjgwYN2TVJSUlCfXq9XPp+vvqYCAACamJBf0fm+mpoaPfHEE/rNb36jPn362MfHjh2rTp06KTY2VgcOHNDMmTNVXFysN998U5Lk9/uDQo4ke9/v9/9gTSAQ0Llz59SqVatLxlNVVaWqqip7PxAIhGaiAADgmlSvQSc9PV1FRUXasWNH0PFJkybZPyckJCgmJkbDhg3T0aNH1bVr13obT3Z2tubOnVtv/QMAgGtLvb11NWXKFG3YsEHbtm3TjTfe+IO1iYmJkqQjR45IktxutyoqKoJqavdr7+u5Uo3L5brs1RxJyszMVGVlpb19+eWXdZ8YAABoMkIedCzL0pQpU7Ru3Tpt3bpV8fHxVz2nsLBQkhQTEyNJ8ng8+uSTT3Ts2DG7Jjc3Vy6XS71797Zr8vLygvrJzc2Vx+O54uM4nU65XK6gDQAAmCvkQSc9PV1/+9vf9Nprr6lNmzby+/3y+/06d+6cJOno0aOaP3++CgoKVFpaqrffflvjx4/XkCFD1LdvX0lScnKyevfurQcffFD79+/Xe++9p9mzZys9PV1Op1OSNHnyZH3++eeaMWOGDh8+rBdeeEFvvPGGpk6dGuopAQCAJirkQWf58uWqrKzUXXfdpZiYGHtbs2aNJMnhcGjLli1KTk5Wz549NW3aNI0aNUrvvPOO3Ufz5s21YcMGNW/eXB6PRw888IDGjx+vefPm2TXx8fHauHGjcnNz1a9fPy1evFgvvfQSHy0HAAC2MMuyrMYeRGMJBAIKDw9XZWVlyN/G6jxrY0j7w+WVLkhp7CEAABpYXV6/+VtXAADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGIugAAABjEXQAAICxCDoAAMBYBB0AAGAsgg4AADAWQQcAABiLoAMAAIxF0AEAAMYi6AAAAGMRdAAAgLFaNPYAgJ+j86yNjT2EOitdkNLYQwCAXwyu6AAAAGMRdAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQdAABgLIIOAAAwFkEHAAAYi6ADAACMRdABAADGatHYAwB+aTrP2tjYQ/hJShekNPYQAKDOuKIDAACMRdAB
AADGIugAAABjEXQAAICxCDoAAMBYTT7oLFu2TJ07d1bLli2VmJiojz76qLGHBAAArhFNOuisWbNGGRkZevbZZ/Xxxx+rX79+8nq9OnbsWGMPDQAAXAOadNBZsmSJJk6cqIcffli9e/fWihUr1Lp1a61cubKxhwYAAK4BTfYLA8+fP6+CggJlZmbax5o1a6akpCT5fL7LnlNVVaWqqip7v7KyUpIUCARCPr6aqrMh7xNoTB2nrm3sIdRZ0VxvYw8BQD2ofd22LOuqtU026PzjH//QxYsXFR0dHXQ8Ojpahw8fvuw52dnZmjt37iXH4+Li6mWMABpX+J8bewQA6tOpU6cUHh7+gzVNNuj8FJmZmcrIyLD3a2pqdOLECbVv315hYWE/u/9AIKC4uDh9+eWXcrlcP7s//DDWu2Gx3g2HtW5YrHfDCsV6W5alU6dOKTY29qq1TTboREZGqnnz5qqoqAg6XlFRIbfbfdlznE6nnE5n0LGIiIiQj83lcvEfSwNivRsW691wWOuGxXo3rJ+73le7klOryd6M7HA4NGjQIOXl5dnHampqlJeXJ4/H04gjAwAA14ome0VHkjIyMjRhwgTdcsstuvXWW/XnP/9ZZ86c0cMPP9zYQwMAANeAJh10Ro8erb///e+aM2eO/H6/+vfvr02bNl1yg3JDcTqdevbZZy95ewz1g/VuWKx3w2GtGxbr3bAaer3DrB/z2SwAAIAmqMneowMAAHA1BB0AAGAsgg4AADAWQQcAABiLoBNCy5YtU+fOndWyZUslJibqo48+auwhNTlZWVkKCwsL2nr27Gm3f/vtt0pPT1f79u11ww03aNSoUZd8aWRZWZlSUlLUunVrRUVFafr06aqurm7oqVyT8vPzdffddys2NlZhYWFav359ULtlWZozZ45iYmLUqlUrJSUl6bPPPguqOXHihMaNGyeXy6WIiAilpaXp9OnTQTUHDhzQ4MGD1bJlS8XFxWnhwoX1PbVrztXW+qGHHrrkd3348OFBNaz1j5edna1f//rXatOmjaKiopSamqri4uKgmlA9f2zfvl0DBw6U0+lUt27dlJOTU9/Tu6b8mLW+6667Lvn9njx5clBNg621hZBYvXq15XA4rJUrV1oHDx60Jk6caEVERFgVFRWNPbQm5dlnn7Vuvvlmq7y83N7+/ve/2+2TJ0+24uLirLy8PGvv3r3WbbfdZt1+++12e3V1tdWnTx8rKSnJ2rdvn/Xuu+9akZGRVmZmZmNM55rz7rvvWk8//bT15ptvWpKsdevWBbUvWLDACg8Pt9avX2/t37/fuueee6z4+Hjr3Llzds3w4cOtfv36Wbt27bI++OADq1u3btaYMWPs9srKSis6OtoaN26cVVRUZL3++utWq1atrP/8z/9sqGleE6621hMmTLCGDx8e9Lt+4sSJoBrW+sfzer3WK6+8YhUVFVmFhYXWyJEjrY4dO1qnT5+2a0Lx/PH5559brVu3tjIyMqxDhw5Zf/3rX63mzZtbmzZtatD5NqYfs9Z33nmnNXHixKDf78rKSru9IdeaoBMit956q5Wenm7vX7x40YqNjbWys7MbcVRNz7PPPmv169fvsm0nT560rrvuOmvt2rX2sU8//dSSZPl8PsuyvntxadasmeX3++2a5cuXWy6Xy6qqqqrXsTc1//fFt6amxnK73daiRYvsYydPnrScTqf1+uuvW5ZlWYcOHbIkWXv27LFr/ud//scKCwuz/vd//9eyLMt64YUXrLZt2wat98yZM60ePXrU84yuXVcKOvfee+8Vz2Gtf55jx45Zkqz333/fsqzQPX/MmDHDuvnmm4Mea/To0ZbX663vKV2z/u9aW9Z3Qefxxx+/4jkNuda8dRUC58+fV0FBgZKSkuxjzZo1U1JSknw+XyOOrGn67LPPFBsbqy5dumjcuHEqKyuTJBUUFOjChQtB69yzZ0917NjRXmefz6eEhISgL430er0KBAI6ePBgw06kiSkpKZHf
7w9a3/DwcCUmJgatb0REhG655Ra7JikpSc2aNdPu3bvtmiFDhsjhcNg1Xq9XxcXF+uabbxpoNk3D9u3bFRUVpR49eujRRx/V8ePH7TbW+ueprKyUJLVr105S6J4/fD5fUB+1Nb/k5/r/u9a1Vq1apcjISPXp00eZmZk6e/as3daQa92kvxn5WvGPf/xDFy9evOQbmaOjo3X48OFGGlXTlJiYqJycHPXo0UPl5eWaO3euBg8erKKiIvn9fjkcjkv+EGt0dLT8fr8kye/3X/bfobYNV1a7Ppdbv++vb1RUVFB7ixYt1K5du6Ca+Pj4S/qobWvbtm29jL+pGT58uO677z7Fx8fr6NGjeuqppzRixAj5fD41b96ctf4Zampq9MQTT+g3v/mN+vTpI0khe/64Uk0gENC5c+fUqlWr+pjSNetyay1JY8eOVadOnRQbG6sDBw5o5syZKi4u1ptvvimpYdeaoINryogRI+yf+/btq8TERHXq1ElvvPHGL+4JBGa7//777Z8TEhLUt29fde3aVdu3b9ewYcMacWRNX3p6uoqKirRjx47GHorxrrTWkyZNsn9OSEhQTEyMhg0bpqNHj6pr164NOkbeugqByMhINW/e/JK79ysqKuR2uxtpVGaIiIjQTTfdpCNHjsjtduv8+fM6efJkUM3319ntdl/236G2DVdWuz4/9Hvsdrt17NixoPbq6mqdOHGCf4OfqUuXLoqMjNSRI0cksdY/1ZQpU7RhwwZt27ZNN954o308VM8fV6pxuVy/uP8Zu9JaX05iYqIkBf1+N9RaE3RCwOFwaNCgQcrLy7OP1dTUKC8vTx6PpxFH1vSdPn1aR48eVUxMjAYNGqTrrrsuaJ2Li4tVVlZmr7PH49Enn3wS9AKRm5srl8ul3r17N/j4m5L4+Hi53e6g9Q0EAtq9e3fQ+p48eVIFBQV2zdatW1VTU2M/kXk8HuXn5+vChQt2TW5urnr06PGLfSvlx/jqq690/PhxxcTESGKt68qyLE2ZMkXr1q3T1q1bL3lLL1TPHx6PJ6iP2ppf0nP91db6cgoLCyUp6Pe7wda6Trcu44pWr15tOZ1OKycnxzp06JA1adIkKyIiIuiOclzdtGnTrO3bt1slJSXWhx9+aCUlJVmRkZHWsWPHLMv67uOhHTt2tLZu3Wrt3bvX8ng8lsfjsc+v/chicnKyVVhYaG3atMnq0KEDHy//f06dOmXt27fP2rdvnyXJWrJkibVv3z7riy++sCzru4+XR0REWG+99ZZ14MAB6957773sx8sHDBhg7d6929qxY4fVvXv3oI88nzx50oqOjrYefPBBq6ioyFq9erXVunXrX9xHnn9orU+dOmU9+eSTls/ns0pKSqwtW7ZYAwcOtLp37259++23dh+s9Y/36KOPWuHh4db27duDPtJ89uxZuyYUzx+1H3mePn269emnn1rLli37xX28/GprfeTIEWvevHnW3r17rZKSEuutt96yunTpYg0ZMsTuoyHXmqATQn/961+tjh07Wg6Hw7r11lutXbt2NfaQmpzRo0dbMTExlsPhsH71q19Zo0ePto4cOWK3nzt3zvrDH/5gtW3b1mrdurX129/+1iovLw/qo7S01BoxYoTVqlUrKzIy0po2bZp14cKFhp7KNWnbtm2WpEu2CRMmWJb13UfMn3nmGSs6OtpyOp3WsGHDrOLi4qA+jh8/bo0ZM8a64YYbLJfLZT388MPWqVOngmr2799v3XHHHZbT6bR+9atfWQsWLGioKV4zfmitz549ayUnJ1sdOnSwrrvuOqtTp07WxIkTL/kfI9b6x7vcWkuyXnnlFbsmVM8f27Zts/r37285HA6rS5cuQY/xS3C1tS4rK7OGDBlitWvXznI6nVa3bt2s6dOnB32PjmU13FqH/b9BAwAAGId7dAAAgLEIOgAAwFgEHQAAYCyCDgAAMBZBBwAAGIugAwAAjEXQAQAAxiLoAAAAYxF0AACAsQg6AADAWAQd
AABgLIIOAAAw1v8H/s0ghUd3YNgAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def pad_sequence(reviews_tokenized, sequence_length):\n",
    "    ''' returns the tokenized review sequences padded with 0's or truncated to the sequence_length.\n",
    "    '''\n",
    "    padded_reviews = np.zeros((len(reviews_tokenized), sequence_length), dtype = int)\n",
    "    \n",
    "    for idx, review in enumerate(reviews_tokenized):\n",
    "        review_len = len(review)\n",
    "        \n",
    "        if review_len <= sequence_length:\n",
    "            zeroes = list(np.zeros(sequence_length-review_len))\n",
    "            new_sequence = zeroes+review\n",
    "        elif review_len > sequence_length:\n",
    "            new_sequence = review[0:sequence_length]\n",
    "        \n",
    "        padded_reviews[idx,:] = np.array(new_sequence)\n",
    "    \n",
    "    return padded_reviews\n",
    "\n",
    "sequence_length = 512\n",
    "padded_reviews = pad_sequence(reviews_tokenized=reviews_tokenized, sequence_length=sequence_length)\n",
    "\n",
    "plt.hist(reviews_len);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_val_split = 0.75\n",
    "train_X = padded_reviews[:int(train_val_split*len(padded_reviews))]\n",
    "train_y = encoded_label_list[:int(train_val_split*len(padded_reviews))]\n",
    "validation_X = padded_reviews[int(train_val_split*len(padded_reviews)):]\n",
    "validation_y = encoded_label_list[int(train_val_split*len(padded_reviews)):]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "## If while training, you get a runtime error that says: \"RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long\".\n",
    "## simply uncomment run the following lines of code additionally\n",
    "# train_X = train_X.astype('int64')\n",
    "# train_y = train_y.astype('int64')\n",
    "# validation_X = validation_X.astype('int64')\n",
    "# validation_y = validation_y.astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate torch datasets\n",
    "train_dataset = TensorDataset(torch.from_numpy(train_X).to(device), torch.from_numpy(train_y).to(device))\n",
    "validation_dataset = TensorDataset(torch.from_numpy(validation_X).to(device), torch.from_numpy(validation_y).to(device))\n",
    "\n",
    "\n",
    "batch_size = 32\n",
    "# torch dataloaders (shuffle data)\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example Input size:  torch.Size([32, 512])\n",
      "Example Input:\n",
      " tensor([[    0,     0,     0,  ...,    12, 32364,  1943],\n",
      "        [    0,     0,     0,  ...,    55,  1469,  2213],\n",
      "        [    0,     0,     0,  ...,    55,   275,   142],\n",
      "        ...,\n",
      "        [    0,     0,     0,  ...,   285,   611,   142],\n",
      "        [    0,     0,     0,  ...,   224,    14,    73],\n",
      "        [    0,     0,     0,  ...,   691,     8,   142]])\n",
      "\n",
      "Example Output size:  torch.Size([32])\n",
      "Example Output:\n",
      " tensor([1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1.])\n"
     ]
    }
   ],
   "source": [
    "# get a batch of train data\n",
    "train_data_iter = iter(train_dataloader)\n",
    "X_example, y_example = next(train_data_iter)\n",
    "print('Example Input size: ', X_example.size()) # batch_size, seq_length\n",
    "print('Example Input:\\n', X_example)\n",
    "print()\n",
    "print('Example Output size: ', y_example.size()) # batch_size\n",
    "print('Example Output:\\n', y_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n"
     ]
    }
   ],
   "source": [
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_dimension, embedding_dimension, hidden_dimension, output_dimension):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = nn.Embedding(input_dimension, embedding_dimension)  \n",
    "        self.rnn_layer = nn.RNN(embedding_dimension, hidden_dimension, num_layers=1)\n",
    "        self.fc_layer = nn.Linear(hidden_dimension, output_dimension)\n",
    "        \n",
    "    def forward(self, sequence):\n",
    "        # sequence shape = (sequence_length, batch_size)\n",
    "        embedding = self.embedding_layer(sequence)  \n",
    "        # embedding shape = [sequence_length, batch_size, embedding_dimension]\n",
    "        output, hidden_state = self.rnn_layer(embedding)\n",
    "        # output shape = [sequence_length, batch_size, hidden_dimension]\n",
    "        # hidden_state shape = [1, batch_size, hidden_dimension]\n",
    "        final_output = self.fc_layer(hidden_state[-1,:,:].squeeze(0))      \n",
    "        return final_output\n",
    "    \n",
    "input_dimension = len(vocab_to_token)+1 # +1 to account for padding\n",
    "embedding_dimension = 100\n",
    "hidden_dimension = 32\n",
    "output_dimension = 1\n",
    "\n",
    "rnn_model = RNN(input_dimension, embedding_dimension, hidden_dimension, output_dimension)\n",
    "\n",
    "optim = torch.optim.Adam(rnn_model.parameters())\n",
    "loss_func = nn.BCEWithLogitsLoss()\n",
    "\n",
    "rnn_model = rnn_model.to(device)\n",
    "loss_func = loss_func.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy_metric(predictions, ground_truth):\n",
    "    \"\"\"\n",
    "    Returns 0-1 accuracy for the given set of predictions and ground truth\n",
    "    \"\"\"\n",
    "    # round predictions to either 0 or 1\n",
    "    rounded_predictions = torch.round(torch.sigmoid(predictions))\n",
    "    success = (rounded_predictions == ground_truth).float() #convert into float for division \n",
    "    accuracy = success.sum() / len(success)\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, dataloader, optim, loss_func):\n",
    "    \"\"\"\n",
    "    Runs one optimisation pass over dataloader.\n",
    "\n",
    "    model: network called with the transposed batch of sequences\n",
    "    dataloader: iterable of (sequence, sentiment) batches that also supports len()\n",
    "    optim: optimizer stepped once per batch\n",
    "    loss_func: callable taking (predictions, targets) and returning a scalar tensor\n",
    "    Returns (loss, accuracy), each averaged over the number of batches.\n",
    "    \"\"\"\n",
    "    loss = 0\n",
    "    accuracy = 0\n",
    "    model.train()\n",
    "\n",
    "    # NOTE(review): batches are used as yielded - presumably already on the model's\n",
    "    # device and shaped (batch_size, sequence_length); confirm against the dataloader\n",
    "    for sequence, sentiment in dataloader:\n",
    "        optim.zero_grad()\n",
    "        # reshape to the target's shape instead of a bare squeeze(): squeeze() also\n",
    "        # removes the batch axis of a one-sample batch, which the loss then rejects\n",
    "        preds = model(sequence.T).reshape(sentiment.shape)\n",
    "\n",
    "        loss_curr = loss_func(preds, sentiment)\n",
    "        accuracy_curr = accuracy_metric(preds, sentiment)\n",
    "\n",
    "        loss_curr.backward()\n",
    "        optim.step()\n",
    "\n",
    "        loss += loss_curr.item()\n",
    "        accuracy += accuracy_curr.item()\n",
    "\n",
    "    return loss/len(dataloader), accuracy/len(dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate(model, dataloader, loss_func):\n",
    "    \"\"\"\n",
    "    Evaluates model over dataloader without updating any weights.\n",
    "\n",
    "    model: network called with the transposed batch of sequences\n",
    "    dataloader: iterable of (sequence, sentiment) batches that also supports len()\n",
    "    loss_func: callable taking (predictions, targets) and returning a scalar tensor\n",
    "    Returns (loss, accuracy), each averaged over the number of batches.\n",
    "    \"\"\"\n",
    "    loss = 0\n",
    "    accuracy = 0\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for sequence, sentiment in dataloader:\n",
    "            # reshape to the target's shape instead of a bare squeeze(): squeeze() also\n",
    "            # removes the batch axis of a one-sample batch, which the loss then rejects\n",
    "            preds = model(sequence.T).reshape(sentiment.shape)\n",
    "\n",
    "            loss_curr = loss_func(preds, sentiment)\n",
    "            accuracy_curr = accuracy_metric(preds, sentiment)\n",
    "\n",
    "            loss += loss_curr.item()\n",
    "            accuracy += accuracy_curr.item()\n",
    "\n",
    "    return loss/len(dataloader), accuracy/len(dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch number: 1 | time elapsed: 76.7141444683075s\n",
      "training loss: 0.620 | training accuracy: 66.97%\n",
      "validation loss: 0.935 |  validation accuracy: 31.85%\n",
      "\n",
      "epoch number: 2 | time elapsed: 87.61576223373413s\n",
      "training loss: 0.533 | training accuracy: 74.32%\n",
      "validation loss: 0.859 |  validation accuracy: 45.87%\n",
      "\n",
      "epoch number: 3 | time elapsed: 84.97845983505249s\n",
      "training loss: 0.499 | training accuracy: 76.05%\n",
      "validation loss: 1.104 |  validation accuracy: 2.30%\n",
      "\n",
      "epoch number: 4 | time elapsed: 70.93470454216003s\n",
      "training loss: 0.596 | training accuracy: 68.49%\n",
      "validation loss: 1.082 |  validation accuracy: 15.71%\n",
      "\n",
      "epoch number: 5 | time elapsed: 73.70327711105347s\n",
      "training loss: 0.558 | training accuracy: 71.45%\n",
      "validation loss: 1.042 |  validation accuracy: 25.72%\n",
      "\n",
      "epoch number: 6 | time elapsed: 82.68725538253784s\n",
      "training loss: 0.520 | training accuracy: 74.38%\n",
      "validation loss: 1.013 |  validation accuracy: 35.74%\n",
      "\n",
      "epoch number: 7 | time elapsed: 91.2055242061615s\n",
      "training loss: 0.465 | training accuracy: 77.90%\n",
      "validation loss: 0.934 |  validation accuracy: 48.88%\n",
      "\n",
      "epoch number: 8 | time elapsed: 101.68787813186646s\n",
      "training loss: 0.385 | training accuracy: 83.21%\n",
      "validation loss: 0.736 |  validation accuracy: 66.11%\n",
      "\n",
      "epoch number: 9 | time elapsed: 94.43510675430298s\n",
      "training loss: 0.323 | training accuracy: 86.82%\n",
      "validation loss: 0.596 |  validation accuracy: 74.29%\n",
      "\n",
      "epoch number: 10 | time elapsed: 81.06615042686462s\n",
      "training loss: 0.299 | training accuracy: 88.32%\n",
      "validation loss: 1.023 |  validation accuracy: 57.65%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 10\n",
    "best_validation_loss = float('inf')\n",
    "\n",
    "for ep in range(num_epochs):\n",
    "    # wall-clock time of one training pass followed by one validation pass\n",
    "    time_start = time.time()\n",
    "    training_loss, train_accuracy = train(rnn_model, train_dataloader, optim, loss_func)\n",
    "    validation_loss, validation_accuracy = validate(rnn_model, validation_dataloader, loss_func)\n",
    "    time_end = time.time()\n",
    "    time_delta = time_end - time_start\n",
    "\n",
    "    # checkpoint the weights whenever the validation loss reaches a new minimum\n",
    "    if best_validation_loss > validation_loss:\n",
    "        best_validation_loss = validation_loss\n",
    "        torch.save(rnn_model.state_dict(), 'rnn_model.pt')\n",
    "\n",
    "    # per-epoch report; the trailing empty entry prints the blank separator line\n",
    "    report = (\n",
    "        f'epoch number: {ep+1} | time elapsed: {time_delta}s',\n",
    "        f'training loss: {training_loss:.3f} | training accuracy: {train_accuracy*100:.2f}%',\n",
    "        f'validation loss: {validation_loss:.3f} |  validation accuracy: {validation_accuracy*100:.2f}%',\n",
    "        '',\n",
    "    )\n",
    "    for line in report:\n",
    "        print(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sentiment_inference(model, sentence, sequence_length=512):\n",
    "    \"\"\"\n",
    "    Returns the sigmoid of the model output for a single raw sentence, as a Python float.\n",
    "\n",
    "    model: network called with a (sequence_length, 1) tensor of token ids; put into eval mode\n",
    "    sentence: raw text; lower-cased, stripped of punctuation and split on whitespace\n",
    "    sequence_length: fixed input length; shorter inputs are left-padded with 0\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    # text transformations\n",
    "    sentence = sentence.lower()\n",
    "    sentence = ''.join(c for c in sentence if c not in punctuation)\n",
    "    # words missing from the vocabulary fall back to id 0\n",
    "    tokenized = [vocab_to_token.get(token, 0) for token in sentence.split()]\n",
    "    # keep at most sequence_length tokens: a longer sentence previously produced a\n",
    "    # negative pad width, which np.pad rejects.\n",
    "    # NOTE(review): keeping the leading tokens is an assumption - confirm it matches\n",
    "    # how the training sequences were truncated\n",
    "    tokenized = tokenized[:sequence_length]\n",
    "    # explicit int64 buffer so an empty sentence still yields integer ids\n",
    "    padded = np.zeros(sequence_length, dtype=np.int64)\n",
    "    if tokenized:\n",
    "        padded[sequence_length - len(tokenized):] = tokenized\n",
    "\n",
    "    # model inference\n",
    "    model_input = torch.from_numpy(padded).to(device)\n",
    "    model_input = model_input.unsqueeze(1)\n",
    "    # a single prediction does not need the autograd graph\n",
    "    with torch.no_grad():\n",
    "        pred = torch.sigmoid(model(model_input))\n",
    "\n",
    "    return pred.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.10942309349775314\n",
      "0.33907952904701233\n",
      "0.9813206791877747\n",
      "0.8531726002693176\n"
     ]
    }
   ],
   "source": [
    "# print the predicted score for a handful of hand-written reviews\n",
    "sample_reviews = (\n",
    "    'This film is horrible',\n",
    "    'Director tried too hard but this film is bad',\n",
    "    'This film will be houseful for weeks',\n",
    "    'I just really loved the movie',\n",
    ")\n",
    "for review in sample_reviews:\n",
    "    print(sentiment_inference(rnn_model, review))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
